Load the necessary libraries and source the utility functions.

library(Survinng)
library(ggplot2)
library(cowplot)
library(dplyr)
library(tidyr)
library(SurvMetrics)
library(survival)
library(simsurv)
library(survminer)
library(torch)
library(viridis)
library(here)

# Load utility functions
source(here("utils/utils_nn_training.R"))
source(here("utils/utils_plotting.R"))


# Set figure path: all figures and result files go to <project root>/figures_paper
fig_path <- here::here("figures_paper")
# Create the output directory if it is missing. dir.exists() is TRUE only for an
# actual directory (file.exists() is also TRUE for a regular file of that name),
# and recursive = TRUE also creates missing parent directories.
if (!dir.exists(fig_path)) dir.create(fig_path, recursive = TRUE)
# Helper: path of the file `x` inside the figure directory
fig <- function(x) here::here(fig_path, x)

Time-independent effects

Generate the data

Simulation setting:

- \(10,000\) samples (\(9,500\) for training, \(500\) for testing)
- No time-dependent effects
- \(X_1 \sim \mathcal{N}(0,1)\) has a positive effect on the hazard \(\rightarrow\) negative effect on survival
- \(X_2 \sim \mathcal{U}(0,1)\) has a stronger negative effect on the hazard \(\rightarrow\) positive effect on survival
- \(X_3 \sim \mathcal{U}(-1,1)\) has no effect

# Fix the RNG seed so the simulated data and the train/test split are reproducible
set.seed(42)

# Simulate data
n <- 10000
# Covariates: x1 ~ N(0, 1), x2 ~ U(0, 1), x3 ~ U(-1, 1)
x <- data.frame(x1 = rnorm(n), x2 = runif(n, 0, 1), x3 = runif(n, -1, 1))
# Weibull event times. Only x1 and x2 receive coefficients, so x3 does not enter
# the simulation model; maxt = 7 is passed as the maximum follow-up time.
simdat <- simsurv(dist = "weibull", lambdas = 0.1, gammas = 2.5, betas = c(x1 = 1.7, x2 = -2.4),
                  x = x, maxt = 7)
# Drop the first column of the simsurv() output (presumably the id column — TODO confirm)
# and rename the new first column to "time"
y <- simdat[, -1]
colnames(y)[1] <- "time"
# Outcome columns followed by the three covariates
dat <- cbind(y, x)

# Train/test
# 9,500 randomly drawn rows for training; the remaining 500 rows form the test set
idx <- sample(n, 9500)
train <- dat[idx, ]
test <- dat[-idx, ]

Fit the models

# Fit one model per architecture on the training set, passing the test set as well.
# `fit_model` is not defined in this file (presumably it comes from
# utils/utils_nn_training.R — TODO confirm). The returned objects are accessed
# later via `[[1]]` (passed to the explainer) and `$pred` (predictions).
ext_deephit <- fit_model("DeepHit", train, test)
ext_coxtime <- fit_model("CoxTime", train, test)
ext_deepsurv <- fit_model("DeepSurv", train, test)

Create Explainer

# Build one Survinng explainer per model from the first element of each
# fit_model() result, with the test set supplied as `data`.
exp_deephit <- Survinng::explain(ext_deephit[[1]], data = test)
exp_coxtime <- Survinng::explain(ext_coxtime[[1]], data = test)
exp_deepsurv <- Survinng::explain(ext_deepsurv[[1]], data = test)

Performance Measures

# Reshape long-format predictions into a wide matrix.
#
# Arguments:
#   data     Long-format data frame of predictions.
#   id_col   Name (string) of the identifier column; its values become the
#            row names of the result and the column itself is removed.
#   time_col Column whose values become the column names of the result.
#   pred_col Column whose values fill the cells of the result.
#
# Returns a matrix with the `id_col` values as row names and one column per
# distinct `time_col` value (plus any further columns of `data` that are
# neither id, time nor prediction). Stops with an error if `id_col` is missing.
prepare_matrix <- function(data, id_col = "id", time_col = "time", pred_col = "pred") {
  wide_data <- data %>%
    pivot_wider(names_from = {{time_col}}, values_from = {{pred_col}})

  # Work on a plain data frame so that row names can be set
  wide_data <- as.data.frame(wide_data)

  # Fail loudly instead of silently returning a matrix without columns
  if (!id_col %in% names(wide_data)) {
    stop("Column '", id_col, "' not found in `data`.", call. = FALSE)
  }

  # Move the identifier into the row names and remove it from the columns.
  # drop = FALSE keeps the data frame (and its column name) even when only a
  # single time column remains.
  rownames(wide_data) <- wide_data[[id_col]]
  wide_data <- wide_data[, names(wide_data) != id_col, drop = FALSE]

  # Convert to a matrix
  as.matrix(wide_data)
}

# Prepare matrices
# Wide prediction matrices built from the `$pred` element of each fitted model
matrix_coxtime <- prepare_matrix(ext_coxtime$pred)
matrix_deepsurv <- prepare_matrix(ext_deepsurv$pred)
# For DeepHit the first column is dropped, matching `deephit_t` below
matrix_deephit <- prepare_matrix(ext_deephit$pred)[,-1]

# Define survival object
# Observed time and status of the test set
surv_obj <- Surv(test$time, test$status)

# Define time indices and sampled time
# Sorted distinct time points of the CoxTime predictions
t_interest <- sort(unique(ext_coxtime$pred$time))
num_samples <- 100
# `num_samples` evenly spaced (rounded) positions across that time grid
indices <- round(seq(1, length(t_interest), length.out = num_samples))
sampled_t <- t_interest[indices]
# Sorted distinct DeepHit time points without the first one
deephit_t <- sort(unique(ext_deephit$pred$time))[-1]

# Sample matrices
# NOTE(review): `indices` is derived from the CoxTime time grid but is also applied
# to the DeepSurv matrix, and the matrix columns are assumed to be in increasing
# time order — confirm that both models share the same grid.
sampled_matrix_coxtime <- matrix_coxtime[, indices]
sampled_matrix_deepsurv <- matrix_deepsurv[, indices]

# Calculate Brier scores in a single step
#
# Arguments:
#   matrix   Prediction matrix with one column per time point (column i is
#            passed to Brier() as `pre_sp`).
#   times    Time points; times[i] is evaluated together with column i.
#   surv_obj Survival object passed as first argument to Brier().
#
# Returns a numeric vector with one Brier score per column of `matrix`
# (numeric(0) if the matrix has no columns).
calculate_brier <- function(matrix, times, surv_obj) {
  # seq_len() gives an empty sequence for zero columns, whereas 1:0 would
  # iterate over c(1, 0); vapply() guarantees a numeric result of fixed type.
  vapply(
    seq_len(ncol(matrix)),
    function(i) Brier(surv_obj, pre_sp = matrix[, i], times[i]),
    numeric(1)
  )
}
# Brier score per time point: CoxTime and DeepSurv on the sampled time grid,
# DeepHit on its own time grid (`deephit_t`)
metrics_coxtime <- calculate_brier(sampled_matrix_coxtime, sampled_t, surv_obj)
metrics_deepsurv <- calculate_brier(sampled_matrix_deepsurv, sampled_t, surv_obj)
metrics_deephit <- calculate_brier(matrix_deephit, deephit_t, surv_obj)

# Bundle the scores of one model into a long-format data frame for plotting.
#
# Arguments:
#   metrics    Brier scores (stored in column `BS`).
#   times      Time points belonging to `metrics` (stored in column `time`).
#   model_name Label of the model (stored in column `model`).
#
# Returns a data frame with the columns `time`, `BS` and `model`.
combine_results <- function(metrics, times, model_name) {
  columns <- list(time = times, BS = metrics, model = model_name)
  do.call(data.frame, columns)
}
# One long-format data frame per model (each paired with its own time grid),
# stacked row-wise into a single data frame used by the Brier score plot
data_coxtime <- combine_results(metrics_coxtime, sampled_t, "CoxTime")
data_deepsurv <- combine_results(metrics_deepsurv, sampled_t, "DeepSurv")
data_deephit <- combine_results(metrics_deephit, deephit_t, "DeepHit")
data_BS <- rbind(data_coxtime, data_deepsurv, data_deephit)

# Plot Brier scores
# Fixed color per model name
colorblind_palette <- c("CoxTime" = "#E69F00", "DeepSurv" = "#56B4E9", "DeepHit" = "#009E73")
brier_plot_tid <- ggplot() +
  # One Brier score curve per model, distinguished by color and line type
  geom_line(data = data_BS, aes(x = time, y = BS, color = model, linetype = model)) +
  # Rug marks at the observed times of the test set
  geom_rug(data = test, aes(x = time), sides = "bl", linewidth = 0.5, alpha = 0.5) +
  # Reference line at a Brier score of 0
  geom_hline(yintercept = 0) +
  scale_color_manual(values = colorblind_palette) +  # Apply custom colors
  scale_linetype_manual(values = c("CoxTime" = "solid", "DeepSurv" = "dashed", "DeepHit" = "dotted")) + 
  # color = NULL / linetype = NULL remove the legend titles
  labs(title = "", x = "Time", y = "Brier Score", color = NULL, linetype = NULL) +
  theme_minimal(base_size = 17) +
  theme(legend.position = "bottom")
brier_plot_tid 


# Save plot
ggsave(fig("sim_tid_brier_plot.pdf"), plot = brier_plot_tid, width = 7, height = 5)

# Calculate C-index and IBS for each model
#
# C-index based on a single column of a prediction matrix.
#
# Arguments:
#   matrix   Prediction matrix with one column per time point.
#   surv_obj Survival object passed as first argument to Cindex().
#   index    Column of `matrix` whose values are passed as `predicted`.
#
# Returns the value of Cindex().
calculate_cindex <- function(matrix, surv_obj, index) {
  predicted_at_index <- matrix[, index]
  Cindex(surv_obj, predicted = predicted_at_index)
}
# Integrated Brier score for one model.
#
# Arguments:
#   matrix   Prediction matrix, passed to IBS() as `sp_matrix`.
#   times    Time points, passed to IBS() as third (positional) argument.
#   surv_obj Survival object passed as first argument to IBS().
#
# Returns the value of IBS().
calculate_ibs <- function(matrix, times, surv_obj) {
  integrated_score <- IBS(surv_obj, sp_matrix = matrix, times)
  integrated_score
}
# C-index computed from one fixed column of each prediction matrix:
# column 50 of the sampled CoxTime/DeepSurv matrices, column 15 of the DeepHit matrix
C_coxtime <- calculate_cindex(sampled_matrix_coxtime, surv_obj, 50)
C_deepsurv <- calculate_cindex(sampled_matrix_deepsurv, surv_obj, 50)
C_deephit <- calculate_cindex(matrix_deephit, surv_obj, 15)
# Integrated Brier score over each model's time grid
IBS_coxtime <- calculate_ibs(sampled_matrix_coxtime, sampled_t, surv_obj)
IBS_deepsurv <- calculate_ibs(sampled_matrix_deepsurv, sampled_t, surv_obj)
# NOTE(review): for DeepHit one further leading column / time point is dropped here,
# in addition to the one already removed when `matrix_deephit` and `deephit_t`
# were built — confirm this is intended.
IBS_deephit <- calculate_ibs(matrix_deephit[,-1], deephit_t[-1], surv_obj)

# Display results
# Collect both metrics per model in one table and store it in the figure directory
res <- data.frame(
  model = c("CoxTime", "DeepSurv", "DeepHit"),
  C_index = c(C_coxtime, C_deepsurv, C_deephit),
  IBS = c(IBS_coxtime, IBS_deepsurv, IBS_deephit)
)
saveRDS(res, fig("sim_tid_performance.rds"))
res
#>      model  C_index      IBS
#> 1  CoxTime 0.809269 0.098946
#> 2 DeepSurv 0.809117 0.099033
#> 3  DeepHit 0.809022 0.141535

Survival Prediction

# Print instances of interest
# Row positions within `test` of the two instances explained below; the row
# names shown in the output (343, 7906) are the ones carried over from `dat`.
tid_ids <- c(13, 387)
print(test[tid_ids, ])
#>           time status        x1        x2          x3
#> 343  2.6653596      1 -0.434617 0.1162303 -0.08053765
#> 7906 0.9577924      1  2.454611 0.2462072 -0.04249294

# Compute Vanilla Gradient
# Gradients with target = "survival" for the two selected test instances,
# once per model. These results are reused for the survival-prediction plot
# below and for the gradient attribution plots.
grad_cox <- surv_grad(exp_coxtime, target = "survival", instance = tid_ids, batch_size = 1e6)
grad_deephit <- surv_grad(exp_deephit, target = "survival", instance = tid_ids, batch_size = 1e6)
grad_deepsurv <- surv_grad(exp_deepsurv, target = "survival", instance = tid_ids, batch_size = 1e6)

# Plot survival predictions
# One panel per model, stacked vertically; `plot_surv_pred` is not defined in
# this file (presumably from utils/utils_plotting.R — TODO confirm)
surv_plot <- cowplot::plot_grid(
  plot_surv_pred(grad_cox) ,
  plot_surv_pred(grad_deephit),
  plot_surv_pred(grad_deepsurv),
  nrow = 3, labels = c("CoxTime", "DeepHit", "DeepSurv"),
  label_x = 0.03,      
  label_size = 14) 
surv_plot


# Save plot
ggsave(fig("sim_tid_surv_plot.pdf"), plot = surv_plot, width = 8, height = 14)

Explainable AI

Gradient (Sensitivity)

# Plot attributions
# Vanilla-gradient attributions, one panel per model, stacked vertically
grad_plot <- cowplot::plot_grid(
  plot_attribution(grad_cox, label = "Grad(t)") ,
  plot_attribution(grad_deephit, label = "Grad(t)"),
  plot_attribution(grad_deepsurv, label = "Grad(t)"),
  nrow = 3, labels = c("CoxTime", "DeepHit", "DeepSurv"))
grad_plot


# Save plot
ggsave(fig("sim_tid_grad_plot.pdf"), plot = grad_plot, width = 10, height = 14)
# Plot normalized attributions
# Same gradient results, drawn with normalize = TRUE
grad_plot_norm <- cowplot::plot_grid(
  plot_attribution(grad_cox, normalize = TRUE, label = "Grad(t)"),
  plot_attribution(grad_deephit, normalize = TRUE, label = "Grad(t)"),
  plot_attribution(grad_deepsurv, normalize = TRUE, label = "Grad(t)"),
  nrow = 3, labels = c("CoxTime", "DeepHit", "DeepSurv"))
grad_plot_norm


# Save plot
ggsave(fig("sim_tid_grad_plot_norm.pdf"), plot = grad_plot_norm, width = 10, height = 14)

SmoothGrad (Sensitivity)

# Compute SmoothGrad
# SmoothGrad with target = "survival" for the two selected test instances,
# using n = 50 and noise_level = 0.1, once per model
sg_cox <- surv_smoothgrad(exp_coxtime, target = "survival", instance = tid_ids, n = 50, noise_level = 0.1, batch_size = 1e6)
sg_deephit <- surv_smoothgrad(exp_deephit, target = "survival", instance = tid_ids, n = 50, noise_level = 0.1, batch_size = 1e6)
sg_deepsurv <- surv_smoothgrad(exp_deepsurv, target = "survival", instance = tid_ids, n = 50, noise_level = 0.1, batch_size = 1e6)

# Plot attributions
# SmoothGrad attributions, one panel per model, stacked vertically
smoothgrad_plot <- cowplot::plot_grid(
  plot_attribution(sg_cox, label = "SG(t)"), 
  plot_attribution(sg_deephit, label = "SG(t)"), 
  plot_attribution(sg_deepsurv, label = "SG(t)"),
  nrow = 3, labels = c("CoxTime", "DeepHit", "DeepSurv"))
smoothgrad_plot


# Save plot
ggsave(fig("sim_tid_smoothgrad_plot.pdf"), plot = smoothgrad_plot, width = 10, height = 14)
# Plot normalized attributions
# Normalized SmoothGrad attributions, one panel per model. The panels must be
# built from the SmoothGrad results (sg_*): the vanilla-gradient objects
# (grad_*) would merely reproduce the normalized gradient figure under an
# "SG(t)" label.
smoothgrad_plot_norm <- cowplot::plot_grid(
  plot_attribution(sg_cox, normalize = TRUE, label = "SG(t)"),
  plot_attribution(sg_deephit, normalize = TRUE, label = "SG(t)"),
  plot_attribution(sg_deepsurv, normalize = TRUE, label = "SG(t)"),
  nrow = 3, labels = c("CoxTime", "DeepHit", "DeepSurv"))
smoothgrad_plot_norm


# Save plot
ggsave(fig("sim_tid_smoothgrad_plot_norm.pdf"), plot = smoothgrad_plot_norm, width = 10, height = 14)

Gradient x Input

# Compute GradientxInput
# surv_grad() with times_input = TRUE for the two selected test instances;
# `target` is not passed here, so the function's default applies
gradin_cox <- surv_grad(exp_coxtime, instance = tid_ids, times_input = TRUE, batch_size = 1e6)
gradin_deephit <- surv_grad(exp_deephit, instance = tid_ids, times_input = TRUE, batch_size = 1e6)
gradin_deepsurv <- surv_grad(exp_deepsurv, instance = tid_ids, times_input = TRUE, batch_size = 1e6)

# Plot attributions
# Gradient x Input attributions, one panel per model, stacked vertically
gradin_plot <- cowplot::plot_grid(
  plot_attribution(gradin_cox, label = "GxI(t)"), 
  plot_attribution(gradin_deephit, label = "GxI(t)"), 
  plot_attribution(gradin_deepsurv, label = "GxI(t)"),
  nrow = 3, labels = c("CoxTime", "DeepHit", "DeepSurv"))
gradin_plot


# Save plot
ggsave(fig("sim_tid_gradin_plot.pdf"), plot = gradin_plot, width = 10, height = 14)
# Plot attributions
# Direct comparison for DeepSurv: plain gradient (top) vs. Gradient x Input (bottom)
grad_gradin_plot <- cowplot::plot_grid(
  plot_attribution(grad_deepsurv, label = "Grad(t)") ,
  plot_attribution(gradin_deepsurv, label = "GxI(t)"),
  nrow = 2, labels = c("DeepSurv", "DeepSurv"))
grad_gradin_plot


# Save plot
ggsave(fig("sim_tid_grad_gradin_plot.pdf"), plot = grad_gradin_plot, width = 10, height = 9)
# Plot attributions
# Gradient x Input attributions drawn with normalize = TRUE
gradin_plot_norm <- cowplot::plot_grid(
  plot_attribution(gradin_cox, normalize = TRUE, label = "GxI(t)"), 
  plot_attribution(gradin_deephit, normalize = TRUE, label = "GxI(t)"), 
  plot_attribution(gradin_deepsurv, normalize = TRUE, label = "GxI(t)"),
  nrow = 3, labels = c("CoxTime", "DeepHit", "DeepSurv"))
gradin_plot_norm


# Save plot
ggsave(fig("sim_tid_gradin_plot_norm.pdf"), plot = gradin_plot_norm, width = 10, height = 14)

SmoothGrad x Input

# Compute SmoothGradxInput
# surv_smoothgrad() with times_input = TRUE for the two selected test instances.
# Only the explainer is passed positionally; every other argument is named and
# no empty argument is supplied, so all parameters not listed here keep their
# defaults.
# NOTE(review): noise_level is 0.3 here but 0.1 for plain SmoothGrad — confirm
# the difference is intended.
sgin_cox <- surv_smoothgrad(exp_coxtime, instance = tid_ids, n = 50, noise_level = 0.3,
                          times_input = TRUE, batch_size = 1e6)
sgin_deephit <- surv_smoothgrad(exp_deephit, instance = tid_ids, n = 50, noise_level = 0.3,
                              times_input = TRUE, batch_size = 1e6)
sgin_deepsurv <- surv_smoothgrad(exp_deepsurv, instance = tid_ids, n = 50, noise_level = 0.3,
                               times_input = TRUE, batch_size = 1e6)

# Plot attributions
# SmoothGrad x Input attributions, one panel per model, stacked vertically
smoothgradin_plot <- cowplot::plot_grid(
  plot_attribution(sgin_cox, label = "SGxI(t)"), 
  plot_attribution(sgin_deephit, label = "SGxI(t)"), 
  plot_attribution(sgin_deepsurv, label = "SGxI(t)"),
  nrow = 3, labels = c("CoxTime", "DeepHit", "DeepSurv"))
smoothgradin_plot


# Save plot
ggsave(fig("sim_tid_smoothgradin_plot.pdf"), plot = smoothgradin_plot, width = 10, height = 14)
# Plot attributions
# Same SmoothGrad x Input results, drawn with normalize = TRUE
smoothgradin_plot_norm <- cowplot::plot_grid(
  plot_attribution(sgin_cox, normalize = TRUE, label = "SGxI(t)"), 
  plot_attribution(sgin_deephit, normalize = TRUE, label = "SGxI(t)"), 
  plot_attribution(sgin_deepsurv, normalize = TRUE, label = "SGxI(t)"),
  nrow = 3, labels = c("CoxTime", "DeepHit", "DeepSurv"))
smoothgradin_plot_norm


# Save plot
ggsave(fig("sim_tid_smoothgradin_plot_norm.pdf"), plot = smoothgradin_plot_norm, width = 10, height = 14)

Integrated Gradients

Zero baseline

# Compute IntegratedGradient with 0 baseline
# Reference input: a single row with all three features set to 0
x_ref <- matrix(c(0,0,0), nrow = 1)
# surv_intgrad() with n = 50 for the two selected test instances, once per model
ig0_cox <- surv_intgrad(exp_coxtime, instance = tid_ids, n = 50, x_ref = x_ref, batch_size = 1e6)
ig0_deephit <- surv_intgrad(exp_deephit, instance = tid_ids, n = 50, x_ref = x_ref, batch_size = 1e6)
ig0_deepsurv <- surv_intgrad(exp_deepsurv, instance = tid_ids, n = 50, x_ref = x_ref, batch_size = 1e6)

# Plot attributions
# Zero-baseline attributions, one panel per model, stacked vertically
intgrad0_plot <- cowplot::plot_grid(
  plot_attribution(ig0_cox, label = "IntGrad(t)"), 
  plot_attribution(ig0_deephit, label = "IntGrad(t)"), 
  plot_attribution(ig0_deepsurv, label = "IntGrad(t)"),
  nrow = 3, labels = c("CoxTime", "DeepHit", "DeepSurv"))
intgrad0_plot 


# Save plot
ggsave(fig("sim_tid_intgrad0_plot.pdf"), plot = intgrad0_plot, width = 10, height = 14)
# Plot attributions
# Same attributions with add_comp = TRUE (presumably adds comparison curves
# to each panel — TODO confirm against plot_attribution)
intgrad0_plot_comp <- cowplot::plot_grid(
  plot_attribution(ig0_cox, add_comp = TRUE, label = "IntGrad(t)"), 
  plot_attribution(ig0_deephit, add_comp = TRUE, label = "IntGrad(t)"), 
  plot_attribution(ig0_deepsurv, add_comp = TRUE, label = "IntGrad(t)"),
  nrow = 3, labels = c("CoxTime", "DeepHit", "DeepSurv"))
intgrad0_plot_comp 


# Save plot
ggsave(fig("sim_tid_intgrad0_plot_comp.pdf"), plot = intgrad0_plot_comp, width = 10, height = 14)
# Plot contributions
# Contribution plots built from the same zero-baseline results
intgrad0_plot_contr <- cowplot::plot_grid(
  plot_contribution(ig0_cox, label = "IntGrad(t)"), 
  plot_contribution(ig0_deephit, label = "IntGrad(t)"), 
  plot_contribution(ig0_deepsurv, label = "IntGrad(t)"),
  nrow = 3, labels = c("CoxTime", "DeepHit", "DeepSurv"))
intgrad0_plot_contr 


# Save plot
ggsave(fig("sim_tid_intgrad0_plot_contr.pdf"), plot = intgrad0_plot_contr, width = 10, height = 14)
# Plot force
# Force plots; upper_distance / lower_distance are set per model
# (presumably label offsets — TODO confirm against plot_force)
intgrad0_plot_force <- cowplot::plot_grid(
  plot_force(ig0_cox, upper_distance = 0.04, lower_distance = 0.04, label = "IntGrad(t)"), 
  plot_force(ig0_deephit, upper_distance = 0.02, lower_distance = 0, label = "IntGrad(t)"), 
  plot_force(ig0_deepsurv, upper_distance = 0.04, lower_distance = 0.04, label = "IntGrad(t)"),
  nrow = 3, labels = c("CoxTime", "DeepHit", "DeepSurv"))
intgrad0_plot_force 


# Save plot
ggsave(fig("sim_tid_intgrad0_plot_force.pdf"), plot = intgrad0_plot_force, width = 10, height = 14)

Mean baseline

# Compute IntegratedGradient with mean baseline
# Passing NULL leaves the choice of the reference input to surv_intgrad()
x_ref <- NULL # default: feature-wise mean
# surv_intgrad() with n = 50 for the two selected test instances, once per model
igm_cox <- surv_intgrad(exp_coxtime, instance = tid_ids, n = 50, x_ref = x_ref, batch_size = 1e6)
igm_deephit <- surv_intgrad(exp_deephit, instance = tid_ids, n = 50, x_ref = x_ref, batch_size = 1e6)
igm_deepsurv <- surv_intgrad(exp_deepsurv, instance = tid_ids, n = 50, x_ref = x_ref, batch_size = 1e6)

# Plot attributions
# Mean-baseline attributions, one panel per model, stacked vertically
intgradmean_plot <- cowplot::plot_grid(
  plot_attribution(igm_cox, label = "IntGrad(t)"), 
  plot_attribution(igm_deephit, label = "IntGrad(t)"), 
  plot_attribution(igm_deepsurv, label = "IntGrad(t)"),
  nrow = 3, labels = c("CoxTime", "DeepHit", "DeepSurv"))
intgradmean_plot


# Save plot
ggsave(fig("sim_tid_intgradmean_plot.pdf"), plot = intgradmean_plot, width = 10, height = 14)
# Plot attributions
# Same attributions with add_comp = TRUE (presumably adds comparison curves
# to each panel — TODO confirm against plot_attribution)
intgradmean_plot_comp <- cowplot::plot_grid(
  plot_attribution(igm_cox, add_comp = TRUE, label = "IntGrad(t)"), 
  plot_attribution(igm_deephit, add_comp = TRUE, label = "IntGrad(t)"), 
  plot_attribution(igm_deepsurv, add_comp = TRUE, label = "IntGrad(t)"),
  nrow = 3, labels = c("CoxTime", "DeepHit", "DeepSurv"))
intgradmean_plot_comp


# Save plot
ggsave(fig("sim_tid_intgradmean_plot_comp.pdf"), plot = intgradmean_plot_comp, width = 10, height = 14)
# Plot contributions
# Contribution plots built from the same mean-baseline results
intgradm_plot_contr <- cowplot::plot_grid(
  plot_contribution(igm_cox, label = "IntGrad(t)"), 
  plot_contribution(igm_deephit, label = "IntGrad(t)"), 
  plot_contribution(igm_deepsurv, label = "IntGrad(t)"),
  nrow = 3, labels = c("CoxTime", "DeepHit", "DeepSurv"))
intgradm_plot_contr 


# Save plot
ggsave(fig("sim_tid_intgradm_plot_contr.pdf"), plot = intgradm_plot_contr, width = 10, height = 14)
# Plot force
# Force plots with upper_distance and lower_distance both set to 0 for all models
intgradm_plot_force <- cowplot::plot_grid(
  plot_force(igm_cox, upper_distance = 0, lower_distance = 0, label = "IntGrad(t)"), 
  plot_force(igm_deephit, upper_distance = 0, lower_distance = 0, label = "IntGrad(t)"), 
  plot_force(igm_deepsurv, upper_distance = 0, lower_distance = 0, label = "IntGrad(t)"),
  nrow = 3, labels = c("CoxTime", "DeepHit", "DeepSurv"))
intgradm_plot_force


# Save plot
ggsave(fig("sim_tid_intgradm_plot_force.pdf"), plot = intgradm_plot_force, width = 10, height = 14)

GradSHAP

# Compute GradShap
# surv_gradSHAP() with n = 50 and num_samples = 100 for the two selected
# test instances, once per model
gshap_cox <- surv_gradSHAP(exp_coxtime, instance = tid_ids, n = 50, num_samples = 100, 
                           batch_size = 1e6)
gshap_deephit <- surv_gradSHAP(exp_deephit, instance = tid_ids, n = 50, num_samples = 100, 
                           batch_size = 1e6)
gshap_deepsurv <- surv_gradSHAP(exp_deepsurv, instance = tid_ids, n = 50, num_samples = 100, 
                           batch_size = 1e6)

# Plot attributions
# GradSHAP attributions, one panel per model, stacked vertically
gshap_plot <- cowplot::plot_grid(
  plot_attribution(gshap_cox, label = "GradSHAP(t)"), 
  plot_attribution(gshap_deephit, label = "GradSHAP(t)"), 
  plot_attribution(gshap_deepsurv, label = "GradSHAP(t)"),
  nrow = 3, labels = c("CoxTime", "DeepHit", "DeepSurv"))
gshap_plot 


# Save plot
ggsave(fig("sim_tid_gshap_plot.pdf"), plot = gshap_plot, width = 10, height = 14)
# Plot attributions
# Same attributions with add_comp = TRUE (presumably adds comparison curves
# to each panel — TODO confirm against plot_attribution)
gshap_plot_comp <- cowplot::plot_grid(
  plot_attribution(gshap_cox, add_comp = TRUE, label = "GradSHAP(t)"), 
  plot_attribution(gshap_deephit, add_comp = TRUE, label = "GradSHAP(t)"), 
  plot_attribution(gshap_deepsurv, add_comp = TRUE, label = "GradSHAP(t)"),
  nrow = 3, labels = c("CoxTime", "DeepHit", "DeepSurv"))
gshap_plot_comp


# Save plot
ggsave(fig("sim_tid_gshap_plot_comp.pdf"), plot = gshap_plot_comp, width = 10, height = 14)
# Plot contributions %
# Contribution plots built from the same GradSHAP results
gshap_plot_contr <- cowplot::plot_grid(
  plot_contribution(gshap_cox, label = "GradSHAP(t)"), 
  plot_contribution(gshap_deephit, label = "GradSHAP(t)"), 
  plot_contribution(gshap_deepsurv, label = "GradSHAP(t)"),
  nrow = 3, labels = c("CoxTime", "DeepHit", "DeepSurv"))
gshap_plot_contr 


# Save plot
ggsave(fig("sim_tid_gshap_plot_contr.pdf"), plot = gshap_plot_contr, width = 10, height = 14)
# Plot force
# Force plots; lower_distance_x1 is set per model (presumably a label offset
# specific to x1 — TODO confirm against plot_force)
gshap_plot_force <- cowplot::plot_grid(
  plot_force(gshap_cox, upper_distance = 0, lower_distance = 0, lower_distance_x1 = 0.05, label = "GradSHAP(t)"), 
  plot_force(gshap_deephit, upper_distance = 0, lower_distance = 0, lower_distance_x1 = 0.02, label = "GradSHAP(t)"), 
  plot_force(gshap_deepsurv, upper_distance = 0, lower_distance = 0, lower_distance_x1 = 0.04, label = "GradSHAP(t)"),
  nrow = 3, labels = c("CoxTime", "DeepHit", "DeepSurv"))
gshap_plot_force


# Save plot
ggsave(fig("sim_tid_gshap_plot_force.pdf"), plot = gshap_plot_force, width = 10, height = 14)